#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (C) 2017-2020  The Project X-Ray Authors.
#
# Use of this source code is governed by a ISC-style
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/ISC
#
# SPDX-License-Identifier: ISC
'''
This script solves the fuzzing problem through least-mean-square solution of
an overdetermined linear equation system.

The advantages of this method are:
- Ability to detect negative correlations (tags which require clearing bits)
- Can detect partial correlation tag <-> bit. This happens if for a small
  number of specimens a tag is said to be "1" but in fact it is not due to
  the way Vivado interprets requested features and encodes them into bitstream.
- Easy detection of tags with no corresponding bits by evaluating the solution error.

The solution is computed using the Tikhonov regularization scheme to ensure
numerical stability. The parameter -a can be used to vary the regularization
factor.

By default each tag is solved separately (best results) while they can be
solved all at once (not recommended).

For each tag a vector of weights is calculated. Each weight corresponds to one
bit. Positive values indicate positive correlation and negative values negative
correlation.

Each weight vector is normalized so that maximum absolute weight is equal to
one.

The parameter -t is used to set threshold for those weights. Weights with
values above the threshold and below the "minus" threshold are output as
candidate bits.

For each weight vector a solution error is computed. If the error exceeds
threshold specified using the -e parameter then the tag is considered to
have no bits.

The option -m can be used to filter bits found for the specified tag in all
other tags. This allows removing bits of an "IS_BLOCK_IN_USE" type tag from
other tags responsible for enabling other features of that block.
'''
import sys
import os
import argparse
import itertools
import json

import numpy as np
import numpy.linalg as linalg

from prjxray.util import OpenSafeFile

# =============================================================================


def load_data(file_name, tagfilter=lambda tag: True, address_map=None):
    """
    Reads a segmaker output file and groups its lines into segments.

    A line starting with "seg" opens a new segment. Lines starting with "bit"
    or "tag" add a bit name or a (tag name, tag value) pair to the segment
    that is currently open. Everything that appears before the first "seg"
    line, and any line with another prefix, is ignored. Segments that end up
    with no tags (e.g. all of them were rejected by the filter) are dropped.

    Parameters
    ----------

    file_name:
        Name of the text file with data.
    tagfilter:
        A function for filtering tags. It is called with the tag name and
        should return True to keep the tag or False to skip it.
    address_map:
        Optional dict indexed by tuples (address, offset) containing a list
        of tile names. When given, the segment name is parsed as
        "<hex address>_<decimal offset>" and, if that key is present in the
        map, replaced by the tile names joined with "_or_".

    Returns
    -------
        A list of dicts. Each contains:
        - "seg": Segment name prefixed with the file name and a colon
        - "bit": A list of bit names
        - "tag": A list of tuples (tag name, tag value)
    """

    segments = []
    current = None

    with OpenSafeFile(file_name, "r") as fp:
        for raw_line in fp.readlines():
            text = raw_line.strip()

            # A "seg" line closes the previous segment and opens a new one
            if text.startswith("seg"):
                parts = text.split()

                # Keep the finished segment only if it collected any tags
                if current is not None and len(current["tag"]):
                    segments.append(current)
                current = None

                name = parts[1]

                # Translate the segment address to tile name(s) if possible
                if address_map is not None:
                    pieces = name.split("_")
                    key = (int(pieces[0], base=16), int(pieces[1]))
                    if key in address_map:
                        name = "_or_".join(address_map[key])

                # Segment names are made unique across files
                current = {
                    "seg": file_name + ":" + name,
                    "bit": [],
                    "tag": [],
                }
                continue

            # Nothing to attach bits/tags to yet
            if current is None:
                continue

            if text.startswith("bit"):
                current["bit"].append(text.split()[1])

            elif text.startswith("tag"):
                parts = text.split()
                tag_name = parts[1]

                # The filter is consulted before the value is parsed
                if tagfilter(tag_name):
                    current["tag"].append((tag_name, int(parts[2])))

    # The last segment has no following "seg" line to flush it
    if current is not None and len(current["tag"]):
        segments.append(current)

    return segments


def write_segbits(file_name, all_tags, all_bits, W):
    """
    Stores the binary solution in a raw database (.rdb) file.

    One line is written per row of W: the tag name followed by the names of
    all bits with a non-zero entry. Bits with a negative entry are prefixed
    with "!". A tag with no bits gets the "<0 candidates>" placeholder.

    Parameters
    ----------

    file_name:
        Name of the .rdb file.
    all_tags:
        List of considered tags (one per row of W).
    all_bits:
        List of considered bits (one per column of W).
    W:
        Matrix with binary solution.
    """
    row_count = W.shape[0]
    col_count = W.shape[1]

    output = []
    for row in range(row_count):
        names = []
        for col in range(col_count):
            value = W[row, col]
            # Negative weight means the bit has to be cleared
            if value < 0:
                names.append("!" + all_bits[col])
            if value > 0:
                names.append(all_bits[col])

        if not names:
            names = ["<0 candidates>"]

        output.append(" ".join([all_tags[row]] + names) + "\n")

    with OpenSafeFile(file_name, "w") as fp:
        for text in output:
            fp.write(text)


def dump_results(fp, all_tags, all_bits, W, X, E, tag_stats=None):
    """
    Dumps solution results to an open file in a nice readable format.

    The output starts with a six-line header in which the names of the bits
    are written vertically (one character per line, "_" shown as "|").
    Bit names longer than six characters are truncated in the header.
    Bits that are not a candidate for any tag are left out entirely.

    Then one line per tag follows: the tag name, a "(!)" marker if the tag
    has no candidate bits, one character per shown bit ("1" positive,
    "0" negative, "-" none), optionally the tag statistics, and finally the
    minimum / maximum raw weight and the solution error of the tag.

    An empty tag list is accepted; only the header is written then.

    Parameters
    ----------

    fp:
        An open file or stream
    all_tags:
        List of considered tags (one per row of W and X).
    all_bits:
        List of considered bits (one per column of W and X).
    W:
        Matrix with binary solution.
    X:
        Matrix with raw solution (floats).
    E:
        Vector with solution errors.
    tag_stats:
        Tag statistics. A dict indexed by tag name holding a pair of counts
        (printed under the "#0" and "#1" header columns).
    """
    bit_len = 6

    # Width of the tag name column. "default" keeps an empty tag list from
    # raising ValueError.
    pad_len = max((len(tag) for tag in all_tags), default=0)

    # Indices of bits which are a candidate for at least one tag
    shown = [c for c in range(len(all_bits)) if not (W[:, c] == 0).all()]

    lines = []

    # Bit names, written vertically. The indent equals the width of the tag
    # name column (pad_len + 1) plus the 4-character marker column.
    names = [all_bits[c].ljust(bit_len).replace("_", "|") for c in shown]
    for i in range(bit_len):
        line = " " * (pad_len + 2 + 3) + "".join(name[i] for name in names)

        if i == (bit_len - 1) and tag_stats is not None:
            line += "  #0   #1 "

        lines.append(line)

    # Tags and bit values
    for r in range(W.shape[0]):
        line = all_tags[r].ljust(pad_len + 1)

        # Flag tags for which no candidate bits were found
        line += "(!) " if (W[r, :] == 0).all() else "    "

        for c in shown:
            b = W[r, c]
            if b < 0:
                line += "0"
            elif b > 0:
                line += "1"
            else:
                line += "-"

        if tag_stats is not None:
            line += " %4d|%4d" % tag_stats[all_tags[r]]

        x_min = np.min(X[r, :])
        x_max = np.max(X[r, :])
        line += " lo=%+.3f hi=%+.3f e=%.3f" % (x_min, x_max, E[r])

        lines.append(line)

    lines.append("")

    # Write
    for line in lines:
        fp.write(line + "\n")


def dump_solution_to_csv(fp, all_tags, all_bits, X):
    """
    Writes the raw (floating point) solution as a CSV table.

    The first row holds an empty cell followed by the bit names. Each
    following row holds a tag name followed by the weights from the
    corresponding row of X formatted with "%+e".

    Parameters
    ----------

    fp:
        An open file or stream
    all_tags:
        List of considered tags (one per row of X).
    all_bits:
        List of considered bits (one per column of X).
    X:
        Matrix with raw solution (floats).
    """

    # Header: the leading empty string yields the empty corner cell
    header_cells = [""] + list(all_bits)
    fp.write(",".join(header_cells) + "\n")

    # One row per tag
    column_count = X.shape[1]
    for row, tag in enumerate(all_tags):
        cells = [tag]
        cells.extend("%+e" % X[row, col] for col in range(column_count))
        fp.write(",".join(cells) + "\n")


def dump_correlation_report(
        fp, all_tags, all_bits, W, C, correlation_exceptions):
    """
    Writes a report of bits which do not fully correlate with their tags.

    Tags with no recorded exceptions are skipped. For every other tag its
    name is written, followed by one entry per bit that has exceptions: the
    bit name, the sign of its weight in W, the correlation factor from C as
    a percentage, and the list of exceptions.

    Parameters
    ----------

    fp:
        An open file or stream
    all_tags:
        List of considered tags (one per row of W and C).
    all_bits:
        List of considered bits (one per column of W and C).
    W:
        Matrix with binary solution.
    C:
        Matrix with correlation factors.
    correlation_exceptions:
        Dict of dicts indexed by tag name and bit name. Each entry is a list
        of 3-element tuples; the first two elements are printed as the
        "is" and "should be" values and the third one as free text.
    """

    for row, tag in enumerate(all_tags):
        per_bit = correlation_exceptions[tag]

        # Nothing to report for this tag (100% correlation)
        if not len(per_bit):
            continue

        fp.write(tag + "\n")

        for col, bit in enumerate(all_bits):
            if bit not in per_bit:
                continue

            # Bit correlation factor along with the sign of the weight
            if W[row, col] > 0:
                sign = "+"
            else:
                sign = "-"
            percent = C[row, col] * 100.0
            fp.write(
                " bit %s: (%s) %.1f%%\n" % (bit.ljust(6), sign, percent))

            # Counter-factual cases
            for actual, expected, where in per_bit[bit]:
                fp.write(
                    "  is %d, should be %d - %s\n" % (actual, expected, where))

    fp.write("\n")


# =============================================================================


def build_matrices(all_tags, all_bits, segdata, bias=0.0):
    """
    Builds matrices for the linear equation system to be solved.

    Each specimen (element of segdata) yields one row in both matrices.

    Parameters
    ----------

    all_tags:
        List of considered tags.
    all_bits:
        List of considered bits.
    segdata:
        List of segdata used.
    bias:
        A constant added to the +1 / -1 value of every tag which is present
        in a specimen. Entries for tags that are absent are not biased.

    Returns
    -------

    Tuple with:
    - Matrix A of shape (specimens, bits). +1 if the bit is set in the
      specimen, -1 otherwise.
    - Matrix B of shape (specimens, tags). +1 + bias if the tag value is
      greater than zero, -1 + bias if it is not and 0 if the specimen does
      not mention the tag at all.
    """

    M = len(segdata)
    N = len(all_bits)
    K = len(all_tags)

    A = np.zeros((M, N), dtype=np.float64)
    B = np.zeros((M, K), dtype=np.float64)

    for r, data in enumerate(segdata):

        # A matrix. A set gives constant-time membership tests instead of
        # scanning the specimen's bit list for every column.
        present_bits = set(data["bit"])
        for c, bit in enumerate(all_bits):
            A[r, c] = +1.0 if bit in present_bits else -1.0

        # B matrix. If a tag is listed more than once in a specimen the last
        # value wins, which is also what dict() does.
        tag_values = dict(data["tag"])
        for c, tag in enumerate(all_tags):
            if tag in tag_values:
                v = +1.0 if tag_values[tag] > 0 else -1.0
                B[r, c] = v + bias

    return A, B


def compute_error(A, B, X):
    """
    Computes the solution error separately for each column of B.

    The error of a column is the Euclidean norm (square root of the sum of
    squares) of the difference between the corresponding columns of A * X
    and B.

    Parameters
    ----------

    A:
        Matrix A
    B:
        Matrix B
    X:
        Matrix with computed solution.

    Returns
    -------

    A vector with errors, one element per column of B.
    """

    residual = np.matmul(A, X) - B

    # Columns are reduced one at a time
    errors = [
        np.sqrt(np.sum(np.square(residual[:, col])))
        for col in range(B.shape[1])
    ]

    return np.array(errors, dtype=float)


# =============================================================================


def solve_lms(all_tags, all_bits, segdata, bias=0.0):
    """
    Solves the system with the direct least squares solver of NumPy.

    Parameters
    ----------

    all_tags:
        List of considered tags.
    all_bits:
        List of considered bits.
    segdata:
        List of segdata used.
    bias:
        Passed unchanged to build_matrices().

    Returns
    -------

    Tuple with:
    - Solution matrix X
    - Error vector as returned by compute_error().
    """

    A, B = build_matrices(all_tags, all_bits, segdata, bias)

    # lstsq() also returns residuals, rank and singular values. Only the
    # solution itself is of interest here.
    solution = linalg.lstsq(A, B, rcond=None)[0]
    errors = compute_error(A, B, solution)

    return solution, errors


def solve_tichonov(all_tags, all_bits, segdata, bias=0.0, a=0.0):
    """
    Solves using Tichonov (Tikhonov) regularization method.

    The solution X satisfies (A^T * A + a * I) * X = A^T * B where A and B
    are the matrices returned by build_matrices().

    Parameters
    ----------

    all_tags:
        List of considered tags.
    all_bits:
        List of considered bits.
    segdata:
        List of segdata used.
    bias:
        Passed unchanged to build_matrices().
    a:
        Regularization coefficient.

    Returns
    -------

    Tuple with:
    - Solution matrix X
    - Error vector.

    Raises
    ------

    numpy.linalg.LinAlgError if the regularized matrix is singular.

    """

    N = len(all_bits)

    # Build matrices
    A, B = build_matrices(all_tags, all_bits, segdata, bias)

    # Tikhonov regularization
    # https://en.wikipedia.org/wiki/Tikhonov_regularization
    AtA = np.matmul(A.T, A)
    AtB = np.matmul(A.T, B)

    # The system is solved directly rather than by forming the explicit
    # inverse of the regularized matrix and multiplying by it.
    X = np.linalg.solve(AtA + a * np.eye(N), AtB)

    return X, compute_error(A, B, X)


# =============================================================================


def solve_onebyone(all_tags, all_bits, segdata, solver=solve_lms, **kw):
    """
    Runs the solver once per tag and assembles the partial results.

    For each tag only the specimens which mention that tag are passed to
    the solver. The tag name and the number of such specimens is printed.

    Parameters
    ----------

    all_tags:
        List of considered tags.
    all_bits:
        List of considered bits.
    segdata:
        List of segdata used.
    solver:
        Solver function. Called as solver([tag], all_bits, segdata, **kw),
        expected to return a (solution matrix, error vector) tuple.
    **kw:
        Parameters to solver function.

    Returns
    -------

    Tuple with:
    - Solution matrix X (one column per tag)
    - Error vector (one element per tag).

    """

    tag_count = len(all_tags)

    X = np.empty((len(all_bits), tag_count))
    E = np.empty((tag_count))

    for col, tag in enumerate(all_tags):

        # Specimens in which the tag occurs, regardless of its value
        relevant = []
        for data in segdata:
            if any(entry[0] == tag for entry in data["tag"]):
                relevant.append(data)

        print("%s #%d" % (tag, len(relevant)))

        solution, error = solver([tag], all_bits, relevant, **kw)

        # The solver was given a single tag, hence a single column / element
        X[:, col] = solution[:, 0]
        E[col] = error[0]

    return X, E


# =============================================================================


def detect_candidates(X, th, norm=None):
    """
    Detects candidate bits.

    The solution is transposed so that each row holds weights of one tag,
    optionally normalized, and compared against the threshold.

    Parameters
    ----------

    X:
        Matrix with solution
    th:
        Threshold. Weights greater than +th become +1, weights lower than
        -th become -1, all others become 0.
    norm:
        Normalization scheme. The only one recognized is "max_abs" which
        divides each row by its maximum absolute value. Rows consisting of
        zeros only are left as they are. Any other value disables
        normalization.

    Returns
    -------

    A tuple with:
    - Binary solution matrix W
    - Transposed matrix X. This is the matrix as given, NOT the normalized
      one.

    """

    # Work on a floating point copy so that X itself is not modified and the
    # in-place division below also works for integer input.
    Xt = np.array(X.T, dtype=float)
    W = np.zeros_like(Xt, dtype=int)

    if norm == "max_abs":
        Nv = np.max(np.abs(Xt), axis=1, keepdims=True)

        # A row of zeros would give 0 / 0 = NaN along with a RuntimeWarning.
        # Dividing by one instead keeps such a row at zero; in both cases
        # no candidates are detected in it.
        Nv[Nv == 0] = 1.0

        Xt /= Nv

    W[Xt < -th] = -1
    W[Xt > +th] = +1

    return W, X.T


# =============================================================================


def compute_bit_correlations(tags_to_solve, bits_to_solve, segdata, W):
    """
    Basing on solution given in the matrix W returns a matrix C with
    correlation coefficients of each bit.

    For a given tag only the specimens which mention it are considered. The
    coefficient of a bit is the fraction of those specimens in which the
    state of the bit (1 if set, 0 if not) equals the expected one. The
    expected state is the tag value for a positive weight and one minus
    the tag value for a negative weight. Bits with zero weight are skipped
    and their coefficient is left at 0.

    Also returns a dict of dicts indexed by tag names and bit names with
    correlation exceptions - concrete specimen names where the correlation
    does not occur.

    Parameters
    ----------

    tags_to_solve:
        List of considered tags (one per row of W).
    bits_to_solve:
        List of considered bits (one per column of W).
    segdata:
        List of segdata used.
    W:
        Matrix with binary solution.

    Returns
    -------

    Tuple with:
    - Matrix C of the same shape as W.
    - Dict of dicts with lists of tuples (bit state, expected bit state,
      segment name). Each tag has an entry. A bit has one only if at least
      one exception was found for it.
    """

    C = np.zeros_like(W, dtype=float)
    exceptions = {}

    # Sets give constant-time bit lookup. Built once for all tags and bits.
    seg_bits = [set(data["bit"]) for data in segdata]

    for i, tag in enumerate(tags_to_solve):

        # Collect specimens which mention this tag together with the value
        # of the tag. If the tag is listed more than once the first value
        # is used.
        specimens = []
        for data, bits in zip(segdata, seg_bits):
            values = [v for t, v in data["tag"] if t == tag]
            if values:
                specimens.append((values[0], bits, data["seg"]))

        exceptions[tag] = {}

        # Without any specimen there is nothing to correlate with. Also
        # avoids division by zero below.
        if not specimens:
            continue

        # Compute bit correlation
        for j, bit in enumerate(bits_to_solve):
            w = W[i, j]

            # No correlation with that bit
            if w == 0:
                continue

            corr_sum = 0

            for value, bits, seg in specimens:
                vb = 1 if bit in bits else 0

                # For negative correlation the bit is expected to be the
                # opposite of the tag value
                vt = int(1 - value) if w < 0 else int(value)

                if vt == vb:
                    corr_sum += 1
                else:
                    exceptions[tag].setdefault(bit, []).append((vb, vt, seg))

            # Store correlation
            C[i, j] = corr_sum / len(specimens)

    return C, exceptions


def compute_tag_stats(all_tags, segdata):
    """
    Counts how many times each of the considered tags occurs with a value
    greater than zero and how many times with any other value.

    Parameters
    ----------

    all_tags:
        Considered tags
    segdata:
        List of segdata used

    Returns
    -------

    A dict indexed by tag name with tuples (count of values not greater than
    zero, count of values greater than zero). Every considered tag has an
    entry, even if it never occurs.

    """

    # [zeros, ones] per considered tag
    counters = {tag: [0, 0] for tag in all_tags}

    for data in segdata:
        for name, value in data["tag"]:
            slot = counters.get(name)

            # Not one of the considered tags
            if slot is None:
                continue

            slot[1 if value > 0 else 0] += 1

    return {tag: (zeros, ones) for tag, (zeros, ones) in counters.items()}


def sort_bits(bit_name):
    """
    Sorting key for bit names of the form "<number>_<number>".

    Returns a tuple of the two numbers as integers so that bits are ordered
    numerically instead of alphabetically.
    """

    first, second = map(int, bit_name.split("_"))
    return (first, second)


def build_address_map(tilegrid_file):
    """
    Loads the tilegrid and generates a map (baseaddr, offset) -> tile name(s).

    Only tiles that have a non-empty "bits" entry containing the
    "CLB_IO_CLK" bus are taken into account.

    Parameters
    ----------

    tilegrid_file:
        The tilegrid.json file.

    Returns
    -------

    A dict with lists of tile names, indexed by tuples of integers
    (base address, offset).

    """

    with OpenSafeFile(tilegrid_file, "r") as fp:
        tilegrid = json.load(fp)

    address_map = {}

    for tile_name, tile_data in tilegrid.items():

        # Tile has no bits at all
        if "bits" not in tile_data or not len(tile_data["bits"]):
            continue

        tile_bits = tile_data["bits"]

        # Tile has bits but not on the bus of interest
        if "CLB_IO_CLK" not in tile_bits:
            continue

        bus = tile_bits["CLB_IO_CLK"]

        # The base address is stored as a hexadecimal string
        key = (int(bus["baseaddr"], 16), int(bus["offset"]))

        # More than one tile may share an address
        address_map.setdefault(key, []).append(tile_name)

    return address_map


# =============================================================================


class FileOrStream(object):
    """
    A context manager which yields either a newly opened file or an already
    existing stream.

    If the file name is None or "-" the stream is yielded and it is left
    open on exit. Otherwise the file is opened for writing and closed on
    exit.
    """

    def __init__(self, file_name, stream=sys.stdout):
        self.file_name = file_name
        self.stream = stream
        # Set only while a file opened by this object is in use
        self.fp = None

    def __enter__(self):
        use_stream = self.file_name is None or self.file_name == "-"
        if use_stream:
            return self.stream

        self.fp = open(self.file_name, "w")
        return self.fp

    def __exit__(self, exc_typ, exc_val, exc_tb):
        # The stream is not owned by this object, so it is never closed here
        if self.fp is None:
            return
        self.fp.close()


# =============================================================================


def main():
    """
    The main.

    Parses the command line, loads the tilegrid and all input files, solves
    for tag <-> bit correlations, writes the resulting database file and
    prints a summary. Optionally dumps the raw solution to a CSV file and
    a bit correlation report to a file or stdout.

    The location of the tilegrid is taken from the XRAY_DATABASE_DIR,
    XRAY_DATABASE and XRAY_FABRIC environment variables. The process exits
    with a non-zero code if any of them is not set or if there is no data,
    no tags or no bits to work on.
    """

    # Parse arguments
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter)

    parser.add_argument(
        "files",
        nargs="*",
        type=str,
        help="Input file(s) generated by segmaker")
    parser.add_argument(
        "-o",
        type=str,
        default="segbits.rdb",
        help="Output database file (def. segbits.rdb)")
    parser.add_argument(
        "-f",
        type=str,
        default=None,
        help="Tag filter. Processes only tags containing the specified text")
    parser.add_argument(
        "-t", type=float, default=0.95, help="Candidate threshold (def. 0.95)")
    # A tag is rejected when its error is ABOVE the threshold (see the
    # rejection loop below), hence the wording of the help text.
    parser.add_argument(
        "-e",
        type=float,
        default=0.1,
        help="Solution error threshold above which a tag is rejected (def. 0.1)"
    )
    parser.add_argument(
        "-a",
        type=float,
        default=0.01,
        help="Regularization coefficient (def. 0.01)")
    parser.add_argument(
        "--all",
        action="store_true",
        help="Solve all tags at once (may give worse results)")
    parser.add_argument(
        "-x",
        type=str,
        default=None,
        help="A CSV file name to Write the numerical solution to")
    parser.add_argument(
        "-r",
        type=str,
        default=None,
        help=
        "A text file name to write bit correlation report to. Specify '-' for stdout"
    )

    parser.add_argument(
        "-m",
        type=str,
        default=None,
        help="Mask bits found for this feature in all other features")

    parser.add_argument("-b", type=float, default=0.0, help="Bias")

    parser.add_argument("-no_0", action="store_true", help="Do not output 0s")
    parser.add_argument("-no_1", action="store_true", help="Do not output 1s")

    args = parser.parse_args()

    # Build (baseaddr, offset) -> tile name map. The environment is checked
    # first as os.path.join() fails with a TypeError on an unset variable.
    env_names = ("XRAY_DATABASE_DIR", "XRAY_DATABASE", "XRAY_FABRIC")
    env_values = [os.getenv(name) for name in env_names]
    env_missing = [
        name for name, value in zip(env_names, env_values) if value is None
    ]
    if env_missing:
        print("Missing environment variable(s): " + ", ".join(env_missing))
        sys.exit(-1)

    database_dir = os.path.join(*env_values)
    tilegrid_file = os.path.join(database_dir, "tilegrid.json")
    address_map = build_address_map(tilegrid_file)

    # Compute threshold
    th = args.t

    # Load and filter segdata
    segdata = []

    def tagfilter(tag):
        if args.f is None:
            return True
        return args.f in tag

    for name in args.files:
        print(name)
        segdata.extend(load_data(name, tagfilter, address_map))

    # Make list of all bits
    all_bits = set()
    for seg in segdata:
        all_bits |= set(seg["bit"])
    all_bits = sorted(list(all_bits), key=sort_bits)

    # Detect bits that are always set
    const1_bits = set(all_bits)
    for seg in segdata:
        const1_bits &= set(seg["bit"])

    # Make list of all tags
    all_tags = set()
    for seg in segdata:
        all_tags |= set([tag[0] for tag in seg["tag"]])
    all_tags = sorted(list(all_tags))

    # Count 0s and 1s for each tag
    tag_count = {}
    for seg in segdata:
        for tag, val in seg["tag"]:

            if tag not in tag_count:
                tag_count[tag] = [0, 0]

            if val > 0:
                tag_count[tag][1] += 1
            else:
                tag_count[tag][0] += 1

    # Identify const0 and const1 tags
    const_tags = {}
    for tag in all_tags:
        if tag_count[tag][0] == 0:
            const_tags[tag] = 1
        if tag_count[tag][1] == 0:
            const_tags[tag] = 0

    const0_tags = [t for t, v in const_tags.items() if v == 0]
    const1_tags = [t for t, v in const_tags.items() if v == 1]

    # Print config
    print("# segs:", len(segdata))
    print("# tags:", len(all_tags))
    print("# bits:", len(all_bits))
    print("threshold: %.2f" % th)

    # sys.exit() is used rather than the exit() helper, which is provided
    # by the site module and is not guaranteed to be available.
    if len(segdata) == 0:
        print("No data!")
        sys.exit(-1)

    if len(all_tags) == 0:
        print("No tags!")
        sys.exit(-1)

    if len(all_bits) == 0:
        print("No bits!")
        sys.exit(-1)

    if len(const1_bits):
        print("const 1 bits: " + ", ".join(const1_bits))

    if len(const0_tags):
        print("const 0 tags: " + ", ".join(const0_tags))
    if len(const1_tags):
        print("const 1 tags: " + ", ".join(const1_tags))

    # Data to solve. Constant tags and bits which are always set are left
    # out.
    tags_to_solve = list(all_tags)
    bits_to_solve = list(all_bits)

    for tag in const_tags.keys():
        tags_to_solve.remove(tag)
    for bit in const1_bits:
        bits_to_solve.remove(bit)

    # Statistics
    tag_stats = compute_tag_stats(tags_to_solve, segdata)

    # Solve
    print("Solving...")
    if args.all:
        X, E = solve_tichonov(
            tags_to_solve, bits_to_solve, segdata, bias=args.b, a=args.a)
    else:
        X, E = solve_onebyone(
            tags_to_solve,
            bits_to_solve,
            segdata,
            solver=solve_tichonov,
            bias=args.b,
            a=args.a)

    # Detect candidate bits. From now on X is transposed (tags in rows).
    W, X = detect_candidates(X, th, norm="max_abs")

    # Mask. Bits of every tag containing the given text are cleared in all
    # other tags where they have the same sign.
    if args.m is not None:
        print("Masking out %s" % args.m)
        tags = [t for t in tags_to_solve if args.m in t]
        for tag in tags:
            i = tags_to_solve.index(tag)
            for r in range(len(tags_to_solve)):
                if r == i:
                    continue
                for c in range(len(bits_to_solve)):
                    if W[r, c] == W[i, c]:
                        W[r, c] = 0

    # Reject 0s and/or 1s
    if args.no_0:
        W[W < 0] = 0
    if args.no_1:
        W[W > 0] = 0

    # Reject tags with error greater than threshold
    for r in range(X.shape[0]):
        if E[r] > args.e:
            W[r, :] = 0

    # Compute correlation
    C, correlation_exceptions = compute_bit_correlations(
        tags_to_solve, bits_to_solve, segdata, W)

    # Write segbits
    write_segbits(args.o, tags_to_solve, bits_to_solve, W)

    # Dump to CSV
    if args.x is not None:
        with OpenSafeFile(args.x, "w") as fp:
            dump_solution_to_csv(fp, tags_to_solve, bits_to_solve, X)

    # Dump results
    dump_results(sys.stdout, tags_to_solve, bits_to_solve, W, X, E, tag_stats)

    # Dump correlation report
    if args.r is not None:
        if args.r != "-":
            print("Dumping bit correlation report to '{}'".format(args.r))
        with FileOrStream(args.r, sys.stdout) as fp:
            dump_correlation_report(
                fp, tags_to_solve, bits_to_solve, W, C, correlation_exceptions)


# =============================================================================

# Run the solver only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()
